Hierarchical models

Theory: what is a hierarchical model?

In general: a model with hyperparameters, i.e. parameters that probabilistically control other parameters.

E.g.

\[\begin{align*} y_i &\sim N(\alpha_{group(i)} + \beta \cdot x_i, \sigma) \\ \alpha_{group(i)} &\sim N(\mu, \tau) \end{align*}\]

In this model \(\tau\) is a hyperparameter, since it probabilistically controls the \(\alpha\) parameters. Is \(\mu\) also a hyperparameter?

Hierarchical models are great for describing the situation where you know some measurements have something in common (e.g. they come from the same group), but you don’t know how much.
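For concreteness, here is how data could be simulated from the example model above. This is a sketch with made-up values for all parameters and group sizes, not real data:

```python
import numpy as np

rng = np.random.default_rng(0)

# hyperparameters: mu and tau control the group-level intercepts
mu, tau = 1.0, 0.5
beta, sigma = 2.0, 0.3

n_group, n_obs = 3, 12
alpha = rng.normal(mu, tau, size=n_group)       # alpha_group ~ N(mu, tau)
group = rng.integers(0, n_group, size=n_obs)    # group(i) for each measurement
x = rng.normal(size=n_obs)
y = rng.normal(alpha[group] + beta * x, sigma)  # y_i ~ N(alpha_group(i) + beta * x_i, sigma)
```

Measurements in the same group share an intercept \(\alpha_{group(i)}\), and \(\tau\) expresses how much the groups can differ: that is exactly the "something in common, but you don't know how much" situation.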


Example: always be closing!

Plushycorp employs 10 salespeople who go door to door selling cute plushies. The number of plushies that each salesperson sold every working day for four weeks was recorded. What can Plushycorp find out from this data?

To answer the question in a best-case scenario, we can use a hierarchical model to run a “digital twin” of this experiment with known parameters and data generating process. Specifically, we can assume that the number \(y_{ij}\) of plushies that salesperson \(i\) sells on day \(j\) depends on a combination of factors:

  • The baseline amount \(\mu\) that a totally average salesperson would sell on a normal day
  • The salesperson’s ability \(ability_i\)
  • An effect \(day\ effect_j\) for the day of the week: people are thought to buy fewer and fewer plushies as the week drags on.
  • Some random variation

A good first step for modelling count data is the Poisson distribution, so let’s assume that the sales measurements follow this Poisson model:1

1 Note the use of the log link function.

\[\begin{align*} y_{ij} &\sim Poisson(\lambda_{ij}) \\ \ln\lambda_{ij} &= \mu + ability_i + day\ effect_j \end{align*}\]
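Because of the log link, effects that add on the log scale multiply on the rate scale. A quick check with toy numbers (not fitted values):

```python
import numpy as np

mu = np.log(2.0)   # baseline of 2 plushies per day
ability = 0.3      # a slightly above-average salesperson
day_effect = -0.2  # a slow day

lam = np.exp(mu + ability + day_effect)

# additive on the log scale == multiplicative on the rate scale
assert np.isclose(lam, 2.0 * np.exp(0.3) * np.exp(-0.2))
```

This also conveniently keeps \(\lambda\) positive, as a Poisson rate must be.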

We know that the salespeople have different abilities, but just how different are they? Since this isn’t really clear to Plushycorp, it makes sense to introduce a parameter \(\tau_{ability}\) into the model:

\[\begin{equation*} ability \sim N(0, \tau_{ability}) \end{equation*}\]

Now we have a hierarchical model!

We can make a similar argument for the day of the week effects:2

2 Can you think of a better model for day effects, given the information above?

\[\begin{equation*} day\ effect \sim N(0, \tau_{day}) \end{equation*}\]

Finally we can complete our model by specifying prior distributions for the non-hierarchical parameters:3

3 \(HN\) here refers to the “half-normal” distribution, a decent default prior for hierarchical standard deviations

\[\begin{align*} \mu &\sim LN(0, 1) \\ \tau_{ability} &\sim HN(0, 0.5) \\ \tau_{day} &\sim HN(0, 0.5) \end{align*}\]

To test out our model with fake data, we can use Python to generate a fake set of salespeople and days, then generate sales consistent with our model.

from pathlib import Path
import json
import numpy as np
import pandas as pd

N_SALESPERSON = 10
N_WEEK = 4
DAY_NAMES = ["Mon", "Tue", "Wed", "Thu", "Fri"]
BASELINE = 2  # 2 plushies in one day is fine
TAU_ABILITY = 0.35
TAU_DAY = 0.2

SEED = 12345
DATA_DIR = Path("../data")

rng = np.random.default_rng(seed=SEED)

with open(DATA_DIR / "names.json", "r") as f:
    name_directory = json.load(f)

first_names, surnames = (
    rng.choice(name_list, size=N_SALESPERSON, replace=False)
    for name_list in name_directory.values()
)
names = [
    f"{first_name} {surname}"
    for first_name, surname in zip(first_names, surnames)
]

abilities = rng.normal(loc=0, scale=TAU_ABILITY, size=N_SALESPERSON)

salespeople = pd.DataFrame({"salesperson": names, "ability": abilities})

salespeople
             salesperson   ability
0        Morten Andersen  0.489643
1           Lene Poulsen  0.462804
2          Rasmus Jensen -0.104894
3           Hanne Madsen  0.316022
4        Mette Rasmussen -0.567554
5  Christian Christensen -0.055366
6       Helle Kristensen  0.157319
7       Charlotte Hansen -0.470260
8         Maria Petersen -0.028591
9          Jette Thomsen  0.603659
day_effects = sorted(
    rng.normal(loc=0, scale=TAU_DAY, size=len(DAY_NAMES)),
    reverse=True,  # descending: sales decline as the week goes on
)
days = pd.DataFrame({"day": DAY_NAMES, "day_effect": day_effects})
days
   day  day_effect
0  Mon    0.523632
1  Tue    0.165727
2  Wed    0.155472
3  Thu   -0.191798
4  Fri   -0.241878
sales = (
    days
    .merge(salespeople, how="cross")
    .merge(pd.DataFrame({"week":[1, 2, 3, 4]}), how="cross")
    .assign(
        sales=lambda df: rng.poisson(
            np.exp(np.log(BASELINE) + df["ability"] + df["day_effect"])
        )
    )
    [["week", "day", "salesperson", "day_effect", "ability", "sales"]]
    .copy()
)
sales.head()
   week  day      salesperson  day_effect   ability  sales
0     1  Mon  Morten Andersen    0.523632  0.489643     10
1     2  Mon  Morten Andersen    0.523632  0.489643      3
2     3  Mon  Morten Andersen    0.523632  0.489643      4
3     4  Mon  Morten Andersen    0.523632  0.489643      4
4     1  Mon     Lene Poulsen    0.523632  0.462804      4

Here is a chart of each salesperson’s total sales over the four weeks:

total_sales = (
    sales.groupby("salesperson")["sales"].sum().sort_values(ascending=False)
)

total_sales.plot(kind="bar", ylabel="Plushies sold", title="Four-week sales")

It’s pretty straightforward to represent hierarchical models with Stan, almost like Stan was designed for it!

from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="../src/stan/plushies.stan")
print(model.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability;
 vector[N_day] day_effect;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability ~ normal(0, tau_ability);
  day_effect ~ normal(0, tau_day);
  tau_ability ~ normal(0, 0.5);
  tau_day ~ normal(0, 0.5);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}

import arviz as az
from stanio.json import process_dictionary

def one_encode(l):
    """One-encode a 1d list-like thing."""
    return dict(zip(l, range(1, len(l) + 1)))


salesperson_codes = one_encode(salespeople["salesperson"])
day_codes = one_encode(days["day"])
data_prior = process_dictionary({
        "N": len(sales),
        "N_salesperson": len(salespeople),
        "N_day": len(days),
        "salesperson": sales["salesperson"].map(salesperson_codes),
        "day": sales["day"].map(day_codes),
        "sales": sales["sales"],
        "likelihood": 0
    }
)
data_posterior = data_prior | {"likelihood": 1}
mcmc_prior = model.sample(data=data_prior)
mcmc_posterior = model.sample(data=data_posterior)
idata = az.from_cmdstanpy(
    posterior=mcmc_posterior,
    prior=mcmc_prior,
    log_likelihood="llik",
    posterior_predictive="yrep",
    observed_data=data_posterior,
    coords={
        "salesperson": salespeople["salesperson"],
        "day": days["day"],
        "observation": sales.index
    },
    dims={
        "lambda": ["observation"],
        "ability": ["salesperson"],
        "day_effect": ["day"],
        "llik": ["observation"],
        "yrep": ["observation"]
    }
)
idata
17:23:34 - cmdstanpy - INFO - CmdStan start processing
17:23:35 - cmdstanpy - INFO - CmdStan done processing.
17:23:35 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Consider re-running with show_console=True if the above output is unclear!
17:23:35 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 1 had 19 divergent transitions (1.9%)
    Chain 2 had 1 divergent transitions (0.1%)
    Chain 3 had 17 divergent transitions (1.7%)
    Chain 4 had 4 divergent transitions (0.4%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.
17:23:35 - cmdstanpy - INFO - CmdStan start processing
17:23:35 - cmdstanpy - INFO - CmdStan done processing.
arviz.InferenceData
    • <xarray.Dataset> Size: 13MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB 0.3217 0.7637 ... 0.7574 0.5949
          ability           (chain, draw, salesperson) float64 320kB 0.6564 ... 0.7226
          day_effect        (chain, draw, day) float64 160kB 0.5388 0.2417 ... 0.01487
          tau_ability       (chain, draw) float64 32kB 0.2651 0.6149 ... 0.5042 0.2755
          tau_day           (chain, draw) float64 32kB 0.7837 1.093 ... 0.315 0.3572
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB 1.517 ... 1...
          mu                (chain, draw) float64 32kB 1.379 2.146 ... 2.133 1.813
          lambda            (chain, draw, observation) float64 6MB 4.558 ... 3.79
      Attributes:
          created_at:                 2024-04-24T15:23:36.080006
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          yrep         (chain, draw, observation) float64 6MB 5.0 8.0 2.0 ... 3.0 4.0
      Attributes:
          created_at:                 2024-04-24T15:23:36.085663
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          llik         (chain, draw, observation) float64 6MB -4.493 -1.799 ... -1.818
      Attributes:
          created_at:                 2024-04-24T15:23:36.417443
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB -6.824 -4.298 ... 0.8808 -6.163
          acceptance_rate  (chain, draw) float64 32kB 0.9944 0.7808 ... 0.9857 0.8948
          step_size        (chain, draw) float64 32kB 0.2429 0.2429 ... 0.2209 0.2209
          tree_depth       (chain, draw) int64 32kB 3 5 5 4 5 4 4 4 ... 4 4 5 3 4 4 4
          n_steps          (chain, draw) int64 32kB 7 47 47 15 31 ... 31 15 15 31 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 11.71 22.77 15.0 ... 4.411 14.17
      Attributes:
          created_at:                 2024-04-24T15:23:36.083541
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 26MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB -0.03402 1.136 ... -0.8584
          ability           (chain, draw, salesperson) float64 320kB 0.2521 ... -0....
          day_effect        (chain, draw, day) float64 160kB -0.5265 ... -0.6955
          tau_ability       (chain, draw) float64 32kB 0.5254 0.4491 ... 0.1053 0.1874
          tau_day           (chain, draw) float64 32kB 0.3982 0.6697 ... 0.674 0.4551
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB -0.3084 ......
          mu                (chain, draw) float64 32kB 0.9665 3.114 ... 0.3585 0.4238
          lambda            (chain, draw, observation) float64 6MB 0.7346 ... 0.1846
          yrep              (chain, draw, observation) float64 6MB 1.0 0.0 ... 0.0 1.0
          llik              (chain, draw, observation) float64 6MB -18.92 ... -4.257
      Attributes:
          created_at:                 2024-04-24T15:23:36.409954
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB -5.509 2.78 5.74 ... 14.99 13.16
          acceptance_rate  (chain, draw) float64 32kB 0.9446 0.9892 ... 0.5428 0.2292
          step_size        (chain, draw) float64 32kB 0.2268 0.2268 ... 0.2186 0.2186
          tree_depth       (chain, draw) int64 32kB 4 4 4 3 4 5 4 4 ... 4 3 4 4 4 4 3
          n_steps          (chain, draw) int64 32kB 15 15 15 7 31 ... 15 15 15 31 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 13.2 11.27 ... -8.691 -4.546
      Attributes:
          created_at:                 2024-04-24T15:23:36.413130
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 10kB
      Dimensions:              (N_dim_0: 1, N_salesperson_dim_0: 1, N_day_dim_0: 1,
                                salesperson_dim_0: 200, day_dim_0: 200,
                                sales_dim_0: 200, likelihood_dim_0: 1)
      Coordinates:
        * N_dim_0              (N_dim_0) int64 8B 0
        * N_salesperson_dim_0  (N_salesperson_dim_0) int64 8B 0
        * N_day_dim_0          (N_day_dim_0) int64 8B 0
        * salesperson_dim_0    (salesperson_dim_0) int64 2kB 0 1 2 3 ... 197 198 199
        * day_dim_0            (day_dim_0) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
        * sales_dim_0          (sales_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * likelihood_dim_0     (likelihood_dim_0) int64 8B 0
      Data variables:
          N                    (N_dim_0) int64 8B 200
          N_salesperson        (N_salesperson_dim_0) int64 8B 10
          N_day                (N_day_dim_0) int64 8B 5
          salesperson          (salesperson_dim_0) int64 2kB 1 1 1 1 2 ... 10 10 10 10
          day                  (day_dim_0) int64 2kB 1 1 1 1 1 1 1 1 ... 5 5 5 5 5 5 5
          sales                (sales_dim_0) int64 2kB 10 3 4 4 4 5 6 ... 4 1 1 3 3 2
          likelihood           (likelihood_dim_0) int64 8B 1
      Attributes:
          created_at:                 2024-04-24T15:23:36.415285
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

az.summary(idata, var_names="~lambda", filter_vars="regex")
                                 mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
log_mu                          0.793  0.214   0.389    1.185      0.007    0.006     856.0    1075.0    1.0
ability[Morten Andersen]        0.440  0.176   0.104    0.763      0.005    0.004    1343.0    1386.0    1.0
ability[Lene Poulsen]           0.376  0.175   0.036    0.696      0.005    0.004    1332.0    1278.0    1.0
ability[Rasmus Jensen]         -0.164  0.188  -0.503    0.208      0.005    0.004    1568.0    1593.0    1.0
ability[Hanne Madsen]           0.105  0.183  -0.243    0.449      0.005    0.004    1476.0    1661.0    1.0
ability[Mette Rasmussen]       -0.543  0.213  -0.942   -0.147      0.005    0.004    1985.0    1735.0    1.0
ability[Christian Christensen] -0.122  0.192  -0.476    0.253      0.005    0.004    1601.0    1829.0    1.0
ability[Helle Kristensen]       0.185  0.183  -0.153    0.537      0.005    0.004    1407.0    1320.0    1.0
ability[Charlotte Hansen]      -0.278  0.198  -0.661    0.091      0.005    0.004    1710.0    1663.0    1.0
ability[Maria Petersen]        -0.207  0.193  -0.582    0.146      0.005    0.004    1610.0    1781.0    1.0
ability[Jette Thomsen]          0.412  0.177   0.103    0.780      0.005    0.004    1307.0    1547.0    1.0
day_effect[Mon]                 0.326  0.186  -0.001    0.662      0.006    0.004    1096.0    1317.0    1.0
day_effect[Tue]                 0.105  0.188  -0.224    0.456      0.006    0.006    1121.0    1287.0    1.0
day_effect[Wed]                 0.113  0.184  -0.221    0.447      0.006    0.005    1122.0    1302.0    1.0
day_effect[Thu]                -0.322  0.195  -0.665    0.032      0.006    0.006    1226.0    1184.0    1.0
day_effect[Fri]                -0.193  0.191  -0.533    0.154      0.006    0.006    1218.0    1271.0    1.0
tau_ability                     0.395  0.116   0.209    0.617      0.003    0.002    2167.0    2615.0    1.0
tau_day                         0.344  0.155   0.123    0.632      0.004    0.003    2271.0    1900.0    1.0
mu                              2.263  0.515   1.407    3.142      0.019    0.015     856.0    1075.0    1.0

The problem with hierarchical models: funnels

Did you notice that cmdstanpy printed some divergent transition warnings above? This illustrates a pervasive problem with hierarchical models: funnel-shaped marginal posterior distributions. The plot below shows the values of the parameter \(\tau_{day}\) and the corresponding day effect values for Monday in the prior samples:

az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
);

As we discussed previously, funnels are hard to sample because of their inconsistent characteristic lengths. Unfortunately, they are often unavoidable in hierarchical models. Can you see why from the graph?
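The funnel geometry can be reproduced without running Stan by drawing directly from the prior. Here is a numpy sketch using the model's \(HN(0, 0.5)\) hyperprior:

```python
import numpy as np

rng = np.random.default_rng(0)

# prior draws: tau_day ~ HN(0, 0.5), day_effect ~ N(0, tau_day)
tau_day = np.abs(rng.normal(0.0, 0.5, size=10_000))
day_effect = rng.normal(0.0, tau_day)

# near the funnel's neck (small tau_day) the effect is pinned close to zero;
# in the bowl (large tau_day) it can wander widely
neck = day_effect[tau_day < 0.1]
bowl = day_effect[tau_day > 0.5]
```

No single step size suits both regions, which is why HMC produces divergent transitions in the neck.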

There are three main solutions to funnels: add more information, tune the HMC algorithm, or reparameterise the model.

Add more information

Sampling from the posterior distribution didn’t produce any divergent transitions. This is probably because the extra information in the measurements made the geometry easier to sample. Comparing the prior and posterior marginal distributions illustrates how this can happen: note that the difference in scale between the neck and the bowl of the funnel is less extreme for the posterior samples.

from matplotlib import pyplot as plt
f, ax = plt.subplots()
az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "prior"},
);
az.plot_pair(
    idata.posterior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "posterior"},
);
ax.legend(frameon=False);

If better measurements aren’t available, divergences can often be avoided by searching for extra information that can justify narrower priors.
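As an illustration (a numpy sketch, not a refit of the Stan model), halving the hyperprior scale directly shrinks the funnel the sampler has to traverse:

```python
import numpy as np

rng = np.random.default_rng(0)

def prior_effect_draws(tau_scale, n=50_000):
    """Draw effects from the prior: tau ~ HN(0, tau_scale), effect ~ N(0, tau)."""
    tau = np.abs(rng.normal(0.0, tau_scale, size=n))
    return rng.normal(0.0, tau)

wide = prior_effect_draws(0.5)     # the hyperprior scale used above
narrow = prior_effect_draws(0.25)  # a hypothetical better-informed prior
```

For \(\tau \sim HN(0, s)\) the marginal standard deviation of the effect is exactly \(s\), so a narrower hyperprior concentrates both the neck and the bowl of the funnel.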

Tune the algorithm

Stan allows increasing the length of the warmup phase (iter_warmup, default 1000), bringing the target acceptance probability closer to 1 (adapt_delta, default 0.8), and increasing the leapfrog integrator’s maximum tree depth (max_treedepth, default 10). All of these changes trade speed for reliability.

mcmc_prior_2 = model.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.99,
    max_treedepth=12
)
17:23:36 - cmdstanpy - INFO - CmdStan start processing
17:23:42 - cmdstanpy - INFO - CmdStan done processing.
17:23:42 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Consider re-running with show_console=True if the above output is unclear!
17:23:42 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 4 had 1 divergent transitions (0.1%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.

Unfortunately even quite aggressive tuning doesn’t get rid of all the divergent transitions in this case.

Reparameterise

The idea with reparameterisation is to define auxiliary parameters which don’t have problematic relationships, then recover the problematic parameters later.

“Non-centred” parameterisations take a distribution with the form \(\alpha\sim D(\mu,\sigma)\) and express it as follows:

\[\begin{align*} u &\sim D(0, 1) \\ \alpha &= \mu + u \cdot \sigma \end{align*}\]
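The two parameterisations define the same marginal distribution for \(\alpha\); only the internal representation changes. A quick numpy check with made-up values for \(\mu\) and \(\sigma\):

```python
import numpy as np

rng = np.random.default_rng(0)
mu_, sigma_ = 1.0, 2.0

centred = rng.normal(mu_, sigma_, size=200_000)  # alpha ~ N(mu, sigma)
u = rng.normal(0.0, 1.0, size=200_000)           # u ~ N(0, 1)
non_centred = mu_ + u * sigma_                   # alpha = mu + u * sigma

# both samples have mean mu_ and standard deviation sigma_, but the
# auxiliary parameter u is a priori independent of sigma_, so there is
# no funnel between them for the sampler to navigate
```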

model_nc = CmdStanModel(stan_file="../src/stan/plushies-nc.stan")
print(model_nc.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability_z;
 vector[N_day] day_effect_z;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N_salesperson] ability = ability_z * tau_ability;
 vector[N_day] day_effect = day_effect_z * tau_day;
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability_z ~ normal(0, 1);
  day_effect_z ~ normal(0, 1);
  tau_ability ~ normal(0, 1);
  tau_day ~ normal(0, 1);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}
mcmc_prior_nc = model_nc.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.999,
    max_treedepth=12
)
17:23:42 - cmdstanpy - INFO - CmdStan start processing
17:23:54 - cmdstanpy - INFO - CmdStan done processing.
17:23:54 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 2 had 2 divergent transitions (0.2%)
    Chain 3 had 1 divergent transitions (0.1%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.

Beware of using non-centred parameterisation as a default: it isn’t guaranteed to be better.

So how many plushies do I need to sell?

f, ax = plt.subplots()
az.plot_forest(
    np.exp(idata.posterior["log_mu"] + idata.posterior["ability"]),
    kind="forestplot",
    combined=True,
    ax=ax,
    show=False,
);
ax.scatter(
    np.exp(np.log(BASELINE) + salespeople["ability"]), 
    ax.get_yticks()[::-1], 
    color="red", 
    label="True expected sales",
    zorder=2
)
ax.scatter(
    sales.groupby("salesperson")["sales"].mean().reindex(salespeople["salesperson"]), 
    ax.get_yticks()[::-1], 
    color="black", 
    label="Observed sales per day",
    zorder=3
)
ax.set(title="", xlabel="Number of plushies sold per day")
ax.axvline(BASELINE, linestyle="--", label="baseline", linewidth=0.8, color="black")
ax.legend(frameon=False);

Takeaways

  • Hierarchical models are a powerful way to capture structural information
  • You may run into problematic sampling, but you have options!
  • There is surprisingly little information in low-expected-value count data.
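The last takeaway follows from Poisson arithmetic: a Poisson measurement with mean \(\lambda\) has standard deviation \(\sqrt{\lambda}\), so its relative noise \(1/\sqrt{\lambda}\) grows as the expected count shrinks:

```python
import numpy as np

# relative noise (coefficient of variation) of a single Poisson measurement
for lam in [0.5, 2.0, 20.0, 200.0]:
    cv = np.sqrt(lam) / lam  # equals 1 / sqrt(lam)
    print(f"lambda = {lam:6.1f}: cv = {cv:.2f}")
```

At the baseline rate of 2 plushies per day, a single day’s measurement carries roughly 70% relative noise, which is why even 200 observations leave the parameters fairly uncertain.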